package go_rpc

import (
	"bufio"
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	log "github.com/sirupsen/logrus"
	json "gogame/lib"
	"gogame/server/network/tcp"
	"io"
	"math"
	"runtime/debug"
	"time"
)

var (
	// ErrInvalidPacket is returned by unpackData when the leading magic byte
	// of a packet does not equal MagicNumber.
	ErrInvalidPacket = errors.New("invalid packet")
)

// Message type codes. DoMessage switches on these values to decide how a
// decoded message is handled.
const (
	RpcRequest  = 1 // RPC request
	RpcResponse = 2 // RPC response
	RpcNotice   = 3 // RPC notice (notification)
)

// RpcInterface is the type of rpcHandle.Parent.
//
// NOTE(review): no implementation is visible in this file, so the exact
// contract of each method (what Handle does with the connection, what the
// getters return) must be confirmed against the implementing types.
type RpcInterface interface {
	Handle(conn tcp.ConnChannelInterface)
	GetType() string
	GetName() string
	GetAddress() string
	DoMessage(message *Message)
}

// rpcHandle holds the per-handler state used by the read/write loops and by
// the request/response bookkeeping. The channels and the request table are
// (re)created by Run.
type rpcHandle struct {
	Name   string       // name
	Parent RpcInterface // parent object (not used by the active code paths in this file)

	//ctx context.Context
	//conn *channel.Channel

	Service *rpcService

	writeChan    chan []byte                       // outgoing packets, drained by WriteLoop
	readChan     chan *Message                     // decoded incoming messages, drained by ReadLoop
	Process      int8                              // number of ReadLoop goroutines started by Run (also the readChan buffer size)
	numChan      chan uint32                       // message id generator (see GenerateId)
	requestTable map[uint32]map[string]interface{} // pending RPC requests keyed by message id: callback/result channel, timestamp and kwArgs
}

// GenerateId starts an id generator and returns the unbuffered channel it
// feeds. The first value received is 1 and each following value is one larger;
// when the counter reaches math.MaxUint32 it is reset so that 0 is delivered
// instead and counting continues from there.
//
// The producing goroutine has no stop signal: it stays blocked on the channel
// send for as long as nobody receives.
func GenerateId() chan uint32 {
	ids := make(chan uint32)

	go func() {
		const limit = uint32(math.MaxUint32)
		var next uint32
		for {
			next++
			if next >= limit {
				// Wrap around before the maximum value is ever handed out.
				next = 0
			}
			ids <- next
		}
	}()

	return ids
}

// getId returns the next message id by receiving one value from numChan.
// It blocks until the generator delivers a value.
func (h *rpcHandle) getId() uint32 {
	next := <-h.numChan
	return next
}

// WriteLoop sends every packet queued on writeChan to conn and flushes the
// connection after each successful write. It returns when writeChan is closed.
//
// Write and flush failures are reported through OnError; the loop keeps
// running afterwards. OnError is only invoked for a real (non-nil) error.
func (h *rpcHandle) WriteLoop(ctx context.Context, conn tcp.ConnChannelInterface) {
	for message := range h.writeChan {
		if _, err := conn.Write(message); err != nil {
			h.OnError(ctx, err)
			continue
		}

		fmt.Printf("%s 转发数据到 %s \n", conn.LocalAddr().String(), conn.RemoteAddr().String())

		// Only report the flush result when it actually failed; previously the
		// error hook was called with a nil error after every successful flush.
		if err := conn.Flush(); err != nil {
			h.OnError(ctx, err)
		}
	}
}

// ReadLoop receives messages from readChan and passes each one to DoMessage.
// A panic raised while handling a message is recovered, logged, and followed
// by a stack dump, so one bad message does not terminate the worker. The loop
// ends when readChan is closed.
func (h *rpcHandle) ReadLoop() {
	// guarded processes a single message behind a recover barrier.
	guarded := func(m *Message) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			log.Error(r)
			debug.PrintStack()
		}()
		h.DoMessage(m)
	}

	for message := range h.readChan {
		guarded(message)
	}
}

// DoMessage dispatches one decoded message according to its type.
//
//   - RpcRequest: the body is decoded as a JSON list [methodName, args] and
//     forwarded to handleRequest. A body that cannot be decoded, has fewer than
//     two elements, or whose elements have the wrong types is reported on
//     stdout and dropped instead of panicking.
//   - RpcResponse: the raw body is forwarded to handleResponse.
//   - RpcNotice: accepted and ignored.
func (h *rpcHandle) DoMessage(message *Message) {
	switch message.msgType {
	case RpcRequest:
		var msgList []interface{}
		if err := json.Loads(message.msg, &msgList); err != nil {
			fmt.Println("doMessage err-->", err)
			return
		}
		if len(msgList) < 2 {
			fmt.Println("doMessage err-->", "request body must be [methodName, args]")
			return
		}

		methodName, ok := msgList[0].(string)
		if !ok {
			fmt.Println("doMessage err-->", fmt.Sprintf("method name has type %T, want string", msgList[0]))
			return
		}

		// packRequest serialises the variadic args slice as the second
		// element; with no arguments that slice is nil, which (assuming
		// json.Dumps follows encoding/json) arrives here as JSON null.
		// Treat null as "no arguments".
		var args []interface{}
		if msgList[1] != nil {
			if args, ok = msgList[1].([]interface{}); !ok {
				fmt.Println("doMessage err-->", fmt.Sprintf("args have type %T, want list", msgList[1]))
				return
			}
		}

		h.handleRequest(message.msgId, methodName, args...)
	case RpcResponse:
		h.handleResponse(message.msgId, nil, message.msg)
	case RpcNotice:
		// Notices are currently received and dropped without further handling.
	}
}

// Run initialises the handler's channels, id generator and request table,
// starts the writer goroutine and the message workers, and then reads packets
// from conn on the calling goroutine until OnReceive returns.
//
// Process selects the number of ReadLoop workers (and the readChan buffer
// size). A value of zero or less is treated as one worker: with no worker
// nothing would ever drain readChan and OnReceive would block forever on the
// first packet, and a negative buffer size makes make() panic.
//
// NOTE(review): the goroutines started here only stop when the channels are
// closed (see OnClose); Run itself does not close them.
func (h *rpcHandle) Run(ctx context.Context, conn tcp.ConnChannelInterface) {
	workers := h.Process
	if workers <= 0 {
		workers = 1
	}

	h.readChan = make(chan *Message, workers)
	h.writeChan = make(chan []byte, 100)
	h.numChan = GenerateId()
	h.requestTable = make(map[uint32]map[string]interface{})

	// Outgoing packets.
	go h.WriteLoop(ctx, conn)

	// Incoming message workers.
	for i := int8(0); i < workers; i++ {
		go h.ReadLoop()
	}

	h.OnReceive(ctx, conn)
}

// OnReceive splits the byte stream read from conn into packets, decodes each
// packet with unpackData and queues the resulting message on readChan. It
// returns when the scanner stops (end of stream or read error).
//
// Framing: a packet is a PackHeadLength-byte header followed by a body whose
// length is the big-endian uint64 stored in the last header field.
//
// Decode failures and a non-nil scanner error are reported through OnError.
// The scanner uses its default buffer, so a packet larger than
// bufio.MaxScanTokenSize ends the scan with bufio.ErrTooLong.
func (h *rpcHandle) OnReceive(ctx context.Context, conn tcp.ConnChannelInterface) {
	scanner := bufio.NewScanner(conn)
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
		// Not even a full header yet: ask for more data (at EOF this ends the scan).
		if len(data) < PackHeadLength {
			return 0, nil, nil
		}

		bodyLength := binary.BigEndian.Uint64(data[MagicNumberLength+MsgIdLength+MsgTypeLength : PackHeadLength])

		// Compare in uint64 space against the bytes actually buffered. The
		// length comes from the peer; converting it to int before checking
		// could overflow to a negative value and panic on the slice below.
		if bodyLength > uint64(len(data)-PackHeadLength) {
			return 0, nil, nil
		}

		// A complete packet is returned even when atEOF is set, so the last
		// packet is not lost if it arrives together with end of stream.
		packLength := PackHeadLength + int(bodyLength)
		return packLength, data[:packLength], nil
	})

	for scanner.Scan() {
		msg := GetMessage()
		if err := h.unpackData(bytes.NewReader(scanner.Bytes()), msg); err != nil {
			// Hand the unused message back before reporting the failure.
			PutMessage(msg)
			h.OnError(ctx, err)
			continue
		}
		h.readChan <- msg
	}

	if err := scanner.Err(); err != nil {
		h.OnError(ctx, err)
	}
}

// unpackData decodes one packet from r into msg. All fields are read in
// big-endian order: magic byte, message type, message id, body length
// (uint64), then the body itself.
//
// It returns ErrInvalidPacket when the magic byte differs from MagicNumber,
// or the underlying read error when r runs short.
func (h *rpcHandle) unpackData(r io.Reader, msg *Message) error {
	// Packet marker.
	var magic uint8
	if err := binary.Read(r, binary.BigEndian, &magic); err != nil {
		return err
	}
	if magic != MagicNumber {
		return ErrInvalidPacket
	}

	// Remaining header fields, in wire order: type, id, body length.
	var bodyLength uint64
	for _, field := range []interface{}{&msg.msgType, &msg.msgId, &bodyLength} {
		if err := binary.Read(r, binary.BigEndian, field); err != nil {
			return err
		}
	}

	// Body.
	msg.msg = make([]byte, bodyLength)
	return binary.Read(r, binary.BigEndian, &msg.msg)
}

// packData serialises every value in dataList, in order, with big-endian
// binary encoding and returns the concatenated bytes. The first value that
// binary.Write rejects aborts packing and its error is returned with a nil
// slice.
func (h *rpcHandle) packData(dataList ...interface{}) ([]byte, error) {
	var packed bytes.Buffer
	for i := range dataList {
		err := binary.Write(&packed, binary.BigEndian, dataList[i])
		if err != nil {
			return nil, err
		}
	}
	return packed.Bytes(), nil
}

// packResponse serialises result with json.Dumps and wraps it in a packet
// carrying the given message id and message type.
//
// It returns an error when json.Dumps yields nil (encoding failed) or when
// packData cannot encode one of the fields.
func (h *rpcHandle) packResponse(msgId uint32, msgType uint8, result interface{}) ([]byte, error) {
	msg := json.Dumps(result)
	// A nil result means encoding failed. The check used to be inverted,
	// which rejected every successfully encoded response.
	if msg == nil {
		return nil, fmt.Errorf("packResponse error")
	}

	// Wire order: magic number, message type, message id, body length, body.
	return h.packData(MagicNumber, msgType, msgId, uint64(len(msg)), msg)
}

// packRequest builds a request packet. The body is the json.Dumps encoding of
// the two-element list [methodName, args]; the header carries msgType and
// msgId. It returns nil when encoding the body or packing the fields fails.
func (h *rpcHandle) packRequest(msgId uint32, msgType uint8, methodName string, args ...interface{}) []byte {
	body := json.Dumps([]interface{}{methodName, args})
	if body == nil {
		return nil
	}

	// Wire order: magic number, message type, message id, body length, body.
	packet, err := h.packData(MagicNumber, msgType, msgId, uint64(len(body)), body)
	if err != nil {
		return nil
	}
	return packet
}

// OnError is the handler's error hook; the read and write loops call it for
// write, flush, decode and scanner failures. A nil err is ignored. Non-nil
// errors are printed to stdout so that they are no longer silently discarded.
func (h *rpcHandle) OnError(ctx context.Context, err error) {
	if err == nil {
		return
	}
	fmt.Printf("rpcHandle err!! --> %s \n", err)
}

// OnClose closes readChan and then writeChan, which lets the ReadLoop and
// WriteLoop range loops finish. Both channels are closed unconditionally, so
// calling it twice, or before Run has created the channels, panics.
func (h *rpcHandle) OnClose(ctx context.Context) {
	// Stop the message workers first, then the writer.
	close(h.readChan)
	close(h.writeChan)
}

// Write queues msg on writeChan without blocking. It returns nil when the
// packet was queued and an error when the channel cannot accept it right now.
func (h *rpcHandle) Write(msg []byte) error {
	select {
	case h.writeChan <- msg:
		// Queued for WriteLoop.
		return nil
	default:
	}
	return errors.New("writeChan full~~")
}

// addRequestTable records a pending request under msgId. The stored entry
// holds the result target ("result"), the registration time as Unix seconds
// ("time") and the caller-supplied kwArgs ("kwArgs"). An existing entry with
// the same id is overwritten.
//
// NOTE(review): requestTable is a plain map with no locking in this file —
// confirm that concurrent callers cannot reach it.
func (h *rpcHandle) addRequestTable(msgId uint32, result interface{}, kwArgs map[string]interface{}) {
	entry := make(map[string]interface{}, 3)
	entry["result"] = result
	entry["time"] = time.Now().Unix()
	entry["kwArgs"] = kwArgs

	h.requestTable[msgId] = entry
}

// popRequestTable removes and returns the pending-request entry stored under
// msgId. It returns nil when no entry exists for that id.
func (h *rpcHandle) popRequestTable(msgId uint32) map[string]interface{} {
	entry, found := h.requestTable[msgId]
	if found {
		delete(h.requestTable, msgId)
		return entry
	}
	return nil
}

// Call sends an RPC request and blocks until the matching response arrives,
// returning the raw response body.
//
// It returns nil without waiting when the request cannot be packed or cannot
// be queued for sending (previously both cases blocked forever, because
// nothing had been sent that could ever be answered). It also returns nil if
// the delivered value is not a []byte.
//
// NOTE(review): there is no timeout; if the peer never answers, Call blocks
// indefinitely and the request-table entry stays registered.
func (h *rpcHandle) Call(methodName string, args ...interface{}) []byte {
	msgId := h.getId()

	msg := h.packRequest(msgId, RpcRequest, methodName, args...)
	if msg == nil {
		fmt.Println("Call err-->", "pack request failed")
		return nil
	}

	// Exactly one value is ever delivered per request, so a buffer of one
	// lets the response worker hand it over without waiting for this goroutine.
	retChan := make(chan interface{}, 1)
	h.addRequestTable(msgId, retChan, nil)

	if err := h.Write(msg); err != nil {
		// The request never left: unregister it instead of waiting for a
		// response that cannot come.
		h.popRequestTable(msgId)
		fmt.Println("Call err-->", err)
		return nil
	}

	ret := <-retChan
	data, _ := ret.([]byte)
	return data
}

// Notice sends a one-way RpcNotice packet for methodName with the given args
// and does not wait for any reply. Nothing is sent when the packet cannot be
// built; a failure to queue the packet is ignored.
func (h *rpcHandle) Notice(methodName string, args ...interface{}) {
	packet := h.packRequest(h.getId(), RpcNotice, methodName, args...)
	if packet == nil {
		return
	}
	_ = h.Write(packet)
}

// CallBack sends an RPC request without blocking and registers the last
// element of args to be invoked when the response arrives.
//
// When the last element of args is a
// func(msg interface{}, err error, rid string, trans bool) it is removed from
// the arguments that are sent to the peer. kwArgs is stored alongside it and
// later supplies the "rid" and "trans" values passed to the callback.
//
// An empty args list is allowed (it used to panic on the index expression).
// If the packet cannot be queued, the registration is removed again so the
// request table does not accumulate entries that can never be answered.
func (h *rpcHandle) CallBack(methodName string, kwArgs map[string]interface{}, args ...interface{}) {
	msgId := h.getId()

	var lastArg interface{}
	if n := len(args); n > 0 {
		lastArg = args[n-1]
		if _, callable := lastArg.(func(msg interface{}, err error, rid string, trans bool)); callable {
			// The callback is local state, not part of the remote call.
			args = args[:n-1]
		}
	}

	msg := h.packRequest(msgId, RpcRequest, methodName, args...)
	if msg == nil {
		return
	}

	h.addRequestTable(msgId, lastArg, kwArgs)
	if err := h.Write(msg); err != nil {
		h.popRequestTable(msgId)
	}
}

// Response packs msg as an RpcResponse packet and queues it for sending.
// Nothing is sent when packing fails; a failure to queue the packet is
// ignored.
//
// NOTE(review): the packet is stamped with a freshly generated id, not the id
// of an incoming request — confirm this is what receivers expect.
func (h *rpcHandle) Response(msg interface{}) {
	packet, _ := h.packResponse(h.getId(), RpcResponse, msg)
	if packet == nil {
		return
	}
	_ = h.Write(packet)
}

// handleResponse completes the pending request registered under msgId.
//
// The entry is removed from the request table first. If there is no entry, or
// err is non-nil, a diagnostic is printed and nothing else happens. Otherwise
// the stored target decides how ret is delivered:
//
//   - callback func(msg, err, rid, trans): ret is decoded from JSON into a map
//     and passed as msg. "rid" and "trans" are taken from the stored kwArgs and
//     fall back to "" and false when missing or of the wrong type (previously
//     this panicked). A payload that is not []byte or fails to decode is
//     reported to the callback through its err argument.
//   - chan interface{} (used by Call): ret is sent on the channel unchanged.
//
// Any other target type is ignored.
func (h *rpcHandle) handleResponse(msgId uint32, err error, ret interface{}) {
	asyncResult := h.popRequestTable(msgId)
	if asyncResult == nil {
		fmt.Printf("%d 对应的result不存在\n", msgId)
		return
	}

	if err != nil {
		fmt.Printf("handleResponse err!! --> %s \n", err)
		return
	}

	switch target := asyncResult["result"].(type) {
	case func(msg interface{}, err error, rid string, trans bool):
		// Indexing a nil map is safe, so a missing kwArgs simply yields the
		// zero values below.
		kwArgs, _ := asyncResult["kwArgs"].(map[string]interface{})
		rid, _ := kwArgs["rid"].(string)
		trans, _ := kwArgs["trans"].(bool)

		var msg map[string]interface{}
		var cbErr error
		if raw, ok := ret.([]byte); !ok {
			cbErr = fmt.Errorf("handleResponse: payload has type %T, want []byte", ret)
		} else if loadErr := json.Loads(raw, &msg); loadErr != nil {
			cbErr = loadErr
		}
		target(msg, cbErr, rid, trans)
	case chan interface{}:
		target <- ret
	}
}

// handleRequest answers the request identified by msgId. methodName and args
// are currently not inspected: every request is answered with the fixed
// payload {"s": 1}. Nothing is sent when packing fails; a failure to queue
// the packet is ignored.
func (h *rpcHandle) handleRequest(msgId uint32, methodName string, args ...interface{}) {
	reply := map[string]interface{}{"s": 1}

	packet, _ := h.packResponse(msgId, RpcResponse, reply)
	if packet == nil {
		return
	}
	_ = h.Write(packet)
}
